{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import sys\n",
    "import tarfile\n",
    "from IPython.display import display, Image\n",
    "from scipy import ndimage\n",
    "from six.moves.urllib.request import urlretrieve\n",
    "from six.moves import cPickle as pickle\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import h5py\n",
    "\n",
    "\n",
    "class DigitStructFile:\n",
    "    \"\"\"Thin wrapper around the h5py data of a digitStruct .mat file.\n",
    "\n",
    "    Attributes:\n",
    "        inf: the input h5 matlab file, opened read-only.\n",
    "        digitStructName: the h5 ref to all the file names.\n",
    "        digitStructBbox: the h5 ref to all struct data.\n",
    "    \"\"\"\n",
    "\n",
    "    # Attributes read for every bounding box, in the order they are stored in the result.\n",
    "    BBOX_KEYS = ('height', 'label', 'left', 'top', 'width')\n",
    "\n",
    "    def __init__(self, inf):\n",
    "        \"\"\"Open `inf` read-only and keep handles to its name and bbox datasets.\"\"\"\n",
    "        self.inf = h5py.File(inf, 'r')\n",
    "        self.digitStructName = self.inf['digitStruct']['name']\n",
    "        self.digitStructBbox = self.inf['digitStruct']['bbox']\n",
    "\n",
    "    def close(self):\n",
    "        \"\"\"Close the underlying h5 file.\"\"\"\n",
    "        self.inf.close()\n",
    "\n",
    "    def __enter__(self):\n",
    "        return self\n",
    "\n",
    "    def __exit__(self, exc_type, exc_value, traceback):\n",
    "        self.close()\n",
    "\n",
    "    def getName(self, n):\n",
    "        \"\"\"Return the 'name' string for the n(th) digitStruct.\"\"\"\n",
    "        # dataset[()] reads the whole dataset; Dataset.value was removed in h5py 3.0.\n",
    "        chars = self.inf[self.digitStructName[n][0]][()]\n",
    "        return ''.join(chr(c[0]) for c in chars)\n",
    "\n",
    "    def bboxHelper(self, attr):\n",
    "        \"\"\"Return the values of one bbox attribute as a list.\n",
    "\n",
    "        Handles the coding difference when there is exactly one bbox (the value is\n",
    "        read directly) or an array of bbox (each entry is an h5 ref to the value).\n",
    "        \"\"\"\n",
    "        values = attr[()]  # read the dataset once rather than once per element\n",
    "        if len(attr) > 1:\n",
    "            return [self.inf[ref.item()][()][0][0] for ref in values]\n",
    "        return [values[0][0]]\n",
    "\n",
    "    def getBbox(self, n):\n",
    "        \"\"\"Return a dict of data (one list per attribute) for the n(th) bbox.\"\"\"\n",
    "        bb = self.inf[self.digitStructBbox[n].item()]\n",
    "        return {key: self.bboxHelper(bb[key]) for key in self.BBOX_KEYS}\n",
    "\n",
    "    def getDigitStructure(self, n):\n",
    "        \"\"\"Return the bbox dict of the n(th) digitStruct with its 'name' added.\"\"\"\n",
    "        s = self.getBbox(n)\n",
    "        s['name'] = self.getName(n)\n",
    "        return s\n",
    "\n",
    "    def getAllDigitStructure(self):\n",
    "        \"\"\"Return all the digitStruct from the input file.\"\"\"\n",
    "        return [self.getDigitStructure(i) for i in range(len(self.digitStructName))]\n",
    "\n",
    "    def getAllDigitStructure_ByDigit(self):\n",
    "        \"\"\"Return a restructured version of the dataset (one structure by boxed digit).\n",
    "\n",
    "        Returns a list of such dicts:\n",
    "            'filename' : filename of the sample\n",
    "            'boxes' : list of such dicts (one by digit):\n",
    "                'label' : 1 to 9 corresponding digits. 10 for digit '0' in image.\n",
    "                'left', 'top' : position of bounding box\n",
    "                'width', 'height' : dimension of bounding box\n",
    "\n",
    "        Note: We may turn this to a generator, if memory issues arise.\n",
    "        \"\"\"\n",
    "        result = []\n",
    "        for pict in self.getAllDigitStructure():\n",
    "            boxes = [\n",
    "                {key: pict[key][j] for key in self.BBOX_KEYS}\n",
    "                for j in range(len(pict['height']))\n",
    "            ]\n",
    "            result.append({'filename': pict['name'], 'boxes': boxes})\n",
    "        return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "train_folders = 'train'\n",
    "test_folders = 'test'\n",
    "extra_folders = 'extra'\n",
    "\n",
    "# Parse the training annotations. The parsed result no longer references the h5\n",
    "# file, so the handle is closed as soon as parsing finishes (even on error).\n",
    "fin = os.path.join(train_folders, 'digitStruct.mat')\n",
    "dsf = DigitStructFile(fin)\n",
    "try:\n",
    "    train_data = dsf.getAllDigitStructure_ByDigit()\n",
    "finally:\n",
    "    dsf.inf.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Parse the test annotations; close the h5 handle once parsing finishes (even on error).\n",
    "fin = os.path.join(test_folders, 'digitStruct.mat')\n",
    "dsf = DigitStructFile(fin)\n",
    "try:\n",
    "    test_data = dsf.getAllDigitStructure_ByDigit()\n",
    "finally:\n",
    "    dsf.inf.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Parse the 'extra' annotations; close the h5 handle once parsing finishes (even on error).\n",
    "fin = os.path.join(extra_folders, 'digitStruct.mat')\n",
    "dsf = DigitStructFile(fin)\n",
    "try:\n",
    "    extra_data = dsf.getAllDigitStructure_ByDigit()\n",
    "finally:\n",
    "    dsf.inf.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "\n",
    "# One row per training image holding PIL's im.size, i.e. (width, height).\n",
    "# Every row is filled by the loop, so an uninitialised buffer is sufficient.\n",
    "train_imsize = np.empty([len(train_data), 2])\n",
    "for i, sample in enumerate(train_data):\n",
    "    fullname = os.path.join(train_folders, sample['filename'])\n",
    "    # The with-block closes each image file instead of leaving it to the garbage collector.\n",
    "    with Image.open(fullname) as im:\n",
    "        train_imsize[i, :] = im.size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "876.0 501.0\n",
      "25.0 12.0\n"
     ]
    }
   ],
   "source": [
    "# Column-wise maximum, then minimum, of the two size columns of the training images.\n",
    "print(*np.amax(train_imsize, axis=0))\n",
    "print(*np.amin(train_imsize, axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1083.0 516.0\n",
      "31.0 13.0\n"
     ]
    }
   ],
   "source": [
    "# One row per test image holding PIL's im.size, i.e. (width, height).\n",
    "# Every row is filled by the loop, so an uninitialised buffer is sufficient.\n",
    "test_imsize = np.empty([len(test_data), 2])\n",
    "for i, sample in enumerate(test_data):\n",
    "    fullname = os.path.join(test_folders, sample['filename'])\n",
    "    # The with-block closes each image file instead of leaving it to the garbage collector.\n",
    "    with Image.open(fullname) as im:\n",
    "        test_imsize[i, :] = im.size\n",
    "\n",
    "print(np.amax(test_imsize[:,0]), np.amax(test_imsize[:,1]))\n",
    "print(np.amin(test_imsize[:,0]), np.amin(test_imsize[:,1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# The size table holds at most 70000 rows, matching the extra_data[0:70000] slice\n",
    "# that is converted to a dataset further down. The loop is capped to the table size:\n",
    "# iterating over all of extra_data raised IndexError at row 70000.\n",
    "n_extra = min(len(extra_data), 70000)\n",
    "extra_imsize = np.empty([n_extra, 2])\n",
    "for i in range(n_extra):\n",
    "    fullname = os.path.join(extra_folders, extra_data[i]['filename'])\n",
    "    # The with-block closes each image file instead of leaving it to the garbage collector.\n",
    "    with Image.open(fullname) as im:\n",
    "        extra_imsize[i, :] = im.size\n",
    "\n",
    "print(np.amax(extra_imsize[:,0]), np.amax(extra_imsize[:,1]))\n",
    "print(np.amin(extra_imsize[:,0]), np.amin(extra_imsize[:,1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(array([  410,  4163, 15855, 30484], dtype=int64),)\n",
      "(array([9747], dtype=int64),)\n",
      "(array([ 1722,  2949,  6233, 12862], dtype=int64),)\n",
      "(array([  459,  5352,  7776, 11257, 12191], dtype=int64),)\n",
      "(array([32352], dtype=int64),)\n",
      "(array([19731, 25534, 56510], dtype=int64),)\n"
     ]
    }
   ],
   "source": [
    "# Locate the images whose first size column equals a given value.\n",
    "# 876/25 and 1083/31 are the train/test extremes printed by the cells above.\n",
    "# NOTE(review): 668/22 are presumably the extremes of the 'extra' set - confirm.\n",
    "lookups = (\n",
    "    (train_imsize, 876),\n",
    "    (train_imsize, 25),\n",
    "    (test_imsize, 1083),\n",
    "    (test_imsize, 31),\n",
    "    (extra_imsize, 668),\n",
    "    (extra_imsize, 22),\n",
    ")\n",
    "for sizes, wanted in lookups:\n",
    "    print(np.where(sizes[:,0] == wanted))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# 29929 image has more than 5 digits.\n",
      "(33402, 64, 64, 1) (33402, 6)\n",
      "(13068, 64, 64, 1) (13068, 6)\n"
     ]
    }
   ],
   "source": [
    "import PIL.Image as Image\n",
    "from skimage.color import rgb2gray\n",
    "\n",
    "\n",
    "def generate_dataset(data, folder):\n",
    "    \"\"\"Build the 64x64 grayscale image tensor and the label matrix for `data`.\n",
    "\n",
    "    Each image is cropped to the union of its digit boxes plus a 10% margin on\n",
    "    every side (clamped to the image), resized to 64x64 and converted to grayscale.\n",
    "\n",
    "    Args:\n",
    "        data: list of dicts with a 'filename' and a 'boxes' list; each box is a\n",
    "            dict with 'label', 'top', 'left', 'height' and 'width'.\n",
    "        folder: directory that contains the image files.\n",
    "\n",
    "    Returns:\n",
    "        (dataset, labels): dataset is float32 with shape (len(data), 64, 64, 1);\n",
    "        labels is int with shape (len(data), 6). Column 0 is the number of boxes,\n",
    "        columns 1-5 are the first five digit labels (label 10 is stored as 0) and\n",
    "        unused columns keep the filler value 10.\n",
    "    \"\"\"\n",
    "    dataset = np.empty([len(data), 64, 64, 1], dtype='float32')\n",
    "    labels = np.ones([len(data), 6], dtype=int) * 10\n",
    "    for i, sample in enumerate(data):\n",
    "        boxes = sample['boxes']\n",
    "        labels[i, 0] = len(boxes)\n",
    "        for j, box in enumerate(boxes):\n",
    "            if j < 5:\n",
    "                # Label 10 stands for the digit '0'.\n",
    "                labels[i, j + 1] = 0 if box['label'] == 10 else box['label']\n",
    "            else:\n",
    "                print('#', i, 'image has more than 5 digits.')\n",
    "\n",
    "        top = np.array([box['top'] for box in boxes], dtype='float32')\n",
    "        left = np.array([box['left'] for box in boxes], dtype='float32')\n",
    "        height = np.array([box['height'] for box in boxes], dtype='float32')\n",
    "        width = np.array([box['width'] for box in boxes], dtype='float32')\n",
    "\n",
    "        # Union of all digit boxes. The far edges are the largest bottom/right\n",
    "        # edges, which need not belong to the box with the largest top/left.\n",
    "        box_top = np.amin(top)\n",
    "        box_left = np.amin(left)\n",
    "        box_height = np.amax(top + height) - box_top\n",
    "        box_width = np.amax(left + width) - box_left\n",
    "\n",
    "        fullname = os.path.join(folder, sample['filename'])\n",
    "        with Image.open(fullname) as im:\n",
    "            # Grow the box by 10% per side, then clamp all four sides to the image\n",
    "            # so the crop is never padded with pixels from outside the picture.\n",
    "            crop_top = np.floor(box_top - 0.1 * box_height)\n",
    "            crop_left = np.floor(box_left - 0.1 * box_width)\n",
    "            crop_bottom = min(np.ceil(crop_top + 1.2 * box_height), im.size[1])\n",
    "            crop_right = min(np.ceil(crop_left + 1.2 * box_width), im.size[0])\n",
    "            crop_box = (int(max(crop_left, 0)), int(max(crop_top, 0)),\n",
    "                        int(crop_right), int(crop_bottom))\n",
    "            # Image.ANTIALIAS was an alias of LANCZOS and was removed in Pillow 10.\n",
    "            # Resizing inside the with-block reads the pixels before the file closes.\n",
    "            patch = im.crop(crop_box).resize((64, 64), Image.LANCZOS)\n",
    "\n",
    "        img = rgb2gray(np.array(patch)).astype('float32')\n",
    "        dataset[i] = np.expand_dims(img, axis=-1)\n",
    "\n",
    "    return dataset, labels\n",
    "\n",
    "\n",
    "train_dataset, train_labels = generate_dataset(train_data, train_folders)\n",
    "print(train_dataset.shape, train_labels.shape)\n",
    "\n",
    "test_dataset, test_labels = generate_dataset(test_data, test_folders)\n",
    "print(test_dataset.shape, test_labels.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(70000, 64, 64, 1) (70000, 6)\n"
     ]
    }
   ],
   "source": [
    "# Convert only the leading 70000 samples of the 'extra' set.\n",
    "extra_subset = extra_data[:70000]\n",
    "extra_dataset, extra_labels = generate_dataset(extra_subset, extra_folders)\n",
    "print(*(arr.shape for arr in (extra_dataset, extra_labels)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Drop the training samples with more than 5 digits: the label matrix only has\n",
    "# room for 5. Rows are selected by digit count (label column 0) rather than by a\n",
    "# fixed row index, so re-running this cell does not delete a further sample.\n",
    "too_long = np.where(train_labels[:, 0] > 5)[0]\n",
    "train_dataset = np.delete(train_dataset, too_long, axis=0)\n",
    "train_labels = np.delete(train_labels, too_long, axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(33401, 64, 64, 1) (33401, 6)\n"
     ]
    }
   ],
   "source": [
    "# Shapes of the training images and labels after the row removal above.\n",
    "print(*(arr.shape for arr in (train_dataset, train_labels)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(97722, 64, 64, 1) (97722, 6)\n",
      "(13068, 64, 64, 1) (13068, 6)\n",
      "(5679, 64, 64, 1) (5679, 6)\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "# Fixed seed so the shuffled train/validation order is the same on every run.\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "\n",
    "n_labels = 10\n",
    "valid_index = []\n",
    "valid_index2 = []\n",
    "train_index = []\n",
    "train_index2 = []\n",
    "for i in np.arange(n_labels):\n",
    "    # Group the samples by their first digit (label column 1). Per group, the first\n",
    "    # 400 'train' rows and the first 200 'extra' rows go to validation, the rest to training.\n",
    "    train_rows = np.where(train_labels[:,1] == i)[0]\n",
    "    extra_rows = np.where(extra_labels[:,1] == i)[0]\n",
    "    valid_index.extend(train_rows[:400].tolist())\n",
    "    train_index.extend(train_rows[400:].tolist())\n",
    "    valid_index2.extend(extra_rows[:200].tolist())\n",
    "    train_index2.extend(extra_rows[200:].tolist())\n",
    "\n",
    "random.shuffle(valid_index)\n",
    "random.shuffle(train_index)\n",
    "random.shuffle(valid_index2)\n",
    "random.shuffle(train_index2)\n",
    "\n",
    "valid_dataset = np.concatenate((extra_dataset[valid_index2,:,:,:], train_dataset[valid_index,:,:,:]), axis=0)\n",
    "valid_labels = np.concatenate((extra_labels[valid_index2,:], train_labels[valid_index,:]), axis=0)\n",
    "train_dataset_t = np.concatenate((extra_dataset[train_index2,:,:,:], train_dataset[train_index,:,:,:]), axis=0)\n",
    "train_labels_t = np.concatenate((extra_labels[train_index2,:], train_labels[train_index,:]), axis=0)\n",
    "\n",
    "print(train_dataset_t.shape, train_labels_t.shape)\n",
    "print(test_dataset.shape, test_labels.shape)\n",
    "print(valid_dataset.shape, valid_labels.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Compressed pickle size: 1603422862\n"
     ]
    }
   ],
   "source": [
    "pickle_file = 'SVHN_64x64x1_train.pickle'\n",
    "\n",
    "try:\n",
    "    # The with-block closes the file even when pickling fails part-way.\n",
    "    with open(pickle_file, 'wb') as f:\n",
    "        save = {\n",
    "            'train_dataset': train_dataset_t,\n",
    "            'train_labels': train_labels_t,\n",
    "        }\n",
    "        pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)\n",
    "except Exception as e:\n",
    "    print('Unable to save data to', pickle_file, ':', e)\n",
    "    raise\n",
    "\n",
    "statinfo = os.stat(pickle_file)\n",
    "print('Compressed pickle size:', statinfo.st_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Compressed pickle size: 214420026\n"
     ]
    }
   ],
   "source": [
    "pickle_file = 'SVHN_64x64x1_test.pickle'\n",
    "\n",
    "try:\n",
    "    # The with-block closes the file even when pickling fails part-way.\n",
    "    with open(pickle_file, 'wb') as f:\n",
    "        # NOTE(review): the test arrays are stored under 'train_*' keys; kept as-is\n",
    "        # because code that loads this pickle may rely on them - confirm before renaming.\n",
    "        save = {\n",
    "            'train_dataset': test_dataset,\n",
    "            'train_labels': test_labels,\n",
    "        }\n",
    "        pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)\n",
    "except Exception as e:\n",
    "    print('Unable to save data to', pickle_file, ':', e)\n",
    "    raise\n",
    "\n",
    "statinfo = os.stat(pickle_file)\n",
    "print('Compressed pickle size:', statinfo.st_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Compressed pickle size: 93181314\n"
     ]
    }
   ],
   "source": [
    "pickle_file = 'SVHN_64x64x1_valid.pickle'\n",
    "\n",
    "try:\n",
    "    # The with-block closes the file even when pickling fails part-way.\n",
    "    with open(pickle_file, 'wb') as f:\n",
    "        # NOTE(review): the validation arrays are stored under 'train_*' keys; kept as-is\n",
    "        # because code that loads this pickle may rely on them - confirm before renaming.\n",
    "        save = {\n",
    "            'train_dataset': valid_dataset,\n",
    "            'train_labels': valid_labels,\n",
    "        }\n",
    "        pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)\n",
    "except Exception as e:\n",
    "    print('Unable to save data to', pickle_file, ':', e)\n",
    "    raise\n",
    "\n",
    "statinfo = os.stat(pickle_file)\n",
    "print('Compressed pickle size:', statinfo.st_size)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
